package com.linln.common.config;

import com.linln.common.constant.StatusConst;
import com.linln.common.utils.ReflectUtils;
import org.apache.commons.beanutils.MethodUtils;
import org.hibernate.cfg.ImprovedNamingStrategy;
import org.hibernate.criterion.CriteriaSpecification;
import org.hibernate.query.internal.NativeQueryImpl;
import org.hibernate.transform.ResultTransformer;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Transactional;

import javax.persistence.EntityManager;
import java.sql.Timestamp;
import java.util.*;

/**
 * @author taofucheng
 * @date 2018/8/14
 */
@NoRepositoryBean
public interface BaseRepository<T, ID> extends JpaRepository<T, ID> {
    /**
     * 批量更新数据状态
     * #{#entityName} 实体类对象
     *
     * @param status 状态
     * @param id     ID列表
     * @return 更新数量
     */
    @Modifying
    @Transactional
    @Query("update #{#entityName} set status = ?1  where id in ?2 and status <> " + StatusConst.DELETE)
    public Integer updateStatus(Byte status, List<ID> id);

    /**
     * 完整保存
     *
     * @param entity
     * @param <S>
     * @return
     */
    @Transactional
    <S extends T> S saveFully(S entity);

    /**
     * 执行指定DML语句
     *
     * @param sql
     * @param params
     * @param em
     */
    default void doExecute(String sql, Map<String, Object> params, EntityManager em) {
        javax.persistence.Query q = em.createNativeQuery(sql);
        if (params != null && params.size() > 0) {
            for (String key : params.keySet()) {
                q.setParameter(key, params.get(key));
            }
        }
        q.executeUpdate();
    }

    /**
     * 将指定的sql进行count(*)包裹后查询对应的值
     *
     * @param countSql
     * @return
     */
    default int queryCount(String countSql, Map<String, Object> params, EntityManager em) {
        javax.persistence.Query q = em.createNativeQuery(countSql);
        if (params != null && params.size() > 0) {
            for (String key : params.keySet()) {
                q.setParameter(key, params.get(key));
            }
        }
        try {
            Number n = (Number) q.getSingleResult();
            return n == null ? 0 : n.intValue();
        } catch (javax.persistence.NoResultException e) {
            return 0;
        }
    }

    /**
     * 查询指定的SQL，获取分页数据，会查询总记录数
     *
     * @param page
     * @param sql
     * @param em
     * @param params
     * @return
     */
    default <T> Page<T> doQuery(Pageable page, String sql, Class<T> entityClass, Map<String, Object> params,
                                EntityManager em) {
        int total = queryCount("select count(*) from (" + sql + ") t", params, em);
        // 处理翻页越界问题
        int totalPages = total / page.getPageSize() + (total % page.getPageSize() > 0 ? 1 : 0);
        if (totalPages < page.getPageNumber()) {
            return new PageImpl<T>(new ArrayList<>(), page, total);
        }
        if (total < 1) {
            return new PageImpl<T>(new ArrayList<>(), page, total);
        }
        List<T> entities = doQueryList(page, sql, entityClass, params, em);
        return new PageImpl<T>(entities, page, total);
    }

    /**
     * 查询指定的列表，只返回数据，不查询记录总数
     *
     * @param page
     * @param sql
     * @param entityClass
     * @param params
     * @param em
     * @return
     */
    @SuppressWarnings("unchecked")
    default <T> List<T> doQueryList(Pageable page, String sql, Class<T> entityClass, Map<String, Object> params,
                                    EntityManager em) {
        List<Map<String, Object>> rows = doQueryMap(page, sql, params, em);
        List<T> entities = new ArrayList<>();
        if (rows != null) {
            if (Map.class.isAssignableFrom(entityClass)) {
                entities = (List<T>) rows;
            } else {
                for (Map<String, Object> row : rows) {
                    T entity = ReflectUtils.mapToEntity(row, entityClass);
                    if (entity == null) {
                        continue;
                    }
                    entities.add(entity);
                }
            }
        }
        return entities;
    }

    /**
     * 查询SQL，并用Map结构返回结果
     *
     * @param page
     * @param sql
     * @param params
     * @param em
     * @return
     */
    @SuppressWarnings({"unchecked", "deprecation"})
    default List<Map<String, Object>> doQueryMap(Pageable page, String sql, Map<String, Object> params,
                                                 EntityManager em) {
        // 排序
        if (page != null && page.getSort() != null) {// 添加order
            Iterator<Sort.Order> os = page.getSort().iterator();
            while (os.hasNext()) {
                Sort.Order o = os.next();
                String orderBy = ImprovedNamingStrategy.INSTANCE.classToTableName(o.getProperty()) + " "
                        + o.getDirection().name();
                sql += " order by " + orderBy;
                break;
            }
        }
        javax.persistence.Query q = em.createNativeQuery(sql);
        try {
            MethodUtils.invokeMethod(q, "setResultTransformer", CriteriaSpecification.ALIAS_TO_ENTITY_MAP);
        } catch (Exception ignore) {
        }
        try {//将结果使用LinkedHashMap输出，保证字段的顺序
            ((NativeQueryImpl<?>) q).setResultTransformer(new ResultTransformer() {
                @Override
                public Object transformTuple(Object[] tuple, String[] aliases) {
                    Map<String, Object> one = new LinkedHashMap<>();
                    for (int i = 0; i < aliases.length; i++) {
                        Object val = tuple[i];
                        if (val instanceof Timestamp) {
                            val = ((Timestamp) val).toLocalDateTime();
                        }
                        one.put(aliases[i], val);
                    }
                    return one;
                }

                @Override
                public List transformList(List collection) {
                    return collection;
                }
            });
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (params != null && params.size() > 0) {
            for (String key : params.keySet()) {
                q.setParameter(key, params.get(key));
            }
        }
        int from = (int) page.getOffset();
        if (from < 0) {
            from = 0;
        }
        List<Map<String, Object>> rows = (List<Map<String, Object>>) q.setFirstResult(from)
                .setMaxResults(page.getPageSize()).getResultList();
        return rows;
    }
}
